{ "cells": [ { "cell_type": "markdown", "id": "959034ec2fbaf137", "metadata": { "collapsed": false, "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "# Fairness Metrics\n", "\n", "A key insight that makes this library flexible with respect to how fairness is defined is that many fairness definitions simply compare a statistic between groups.\n", "\n", "In this notebook, we express many well-known fairness definitions in terms of statistics through the use of the `LinearFractionalParity` class. The class is implemented as a `torchmetrics` Metric, which allows it to be integrated into any training or evaluation loop in PyTorch. All that it needs is a statistic to compare.\n", "\n", "To start, make sure `torchmetrics` is installed:" ] }, { "cell_type": "code", "execution_count": 1, "id": "initial_id", "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true }, "ExecuteTime": { "end_time": "2024-07-22T15:17:06.151669900Z", "start_time": "2024-07-22T15:17:04.403287400Z" } }, "outputs": [], "source": [ "import torchmetrics" ] }, { "cell_type": "markdown", "id": "fd72bed4c7aa61cf", "metadata": { "collapsed": false }, "source": [ "If it isn't, you can install it with:\n", "\n", "``pip install torchmetrics``" ] }, { "cell_type": "code", "execution_count": 2, "id": "951cb1f79772f847", "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-26T15:32:10.369985500Z", "start_time": "2024-04-26T15:32:10.361621Z" } }, "outputs": [], "source": [ "from fairret.metric import LinearFractionalParity" ] }, { "cell_type": "markdown", "id": "b2a8b011bcdada61", "metadata": { "collapsed": false }, "source": [ "The simplest statistic to compare is the positive rate, i.e. the rate at which positive predictions are made for each sensitive feature. Equality in these positive rates is typically referred to as *demographic parity* (a.k.a. 
*statistical parity* or *equal acceptance rate*).\n", "\n", "Hence, we can express the extent to which demographic parity holds by passing the `PositiveRate` statistic to a `LinearFractionalParity` metric:" ] }, { "cell_type": "code", "execution_count": 3, "id": "df85652916f4ba1b", "metadata": { "collapsed": false, "editable": true, "slideshow": { "slide_type": "" }, "tags": [ "raises-exception" ], "ExecuteTime": { "end_time": "2024-04-26T15:32:10.567032800Z", "start_time": "2024-04-26T15:32:10.371984100Z" } }, "outputs": [ { "ename": "TypeError", "evalue": "LinearFractionalParity.__init__() missing 1 required positional argument: 'stat_shape'", "output_type": "error", "traceback": [ "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", "\u001B[1;31mTypeError\u001B[0m Traceback (most recent call last)", "Cell \u001B[1;32mIn[3], line 2\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mfairret\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mstatistic\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m PositiveRate\n\u001B[1;32m----> 2\u001B[0m demographic_parity \u001B[38;5;241m=\u001B[39m \u001B[43mLinearFractionalParity\u001B[49m\u001B[43m(\u001B[49m\u001B[43mPositiveRate\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n", "\u001B[1;31mTypeError\u001B[0m: LinearFractionalParity.__init__() missing 1 required positional argument: 'stat_shape'" ] } ], "source": [ "from fairret.statistic import PositiveRate\n", "demographic_parity = LinearFractionalParity(PositiveRate())" ] }, { "cell_type": "markdown", "id": "11f51b0718c0cc6f", "metadata": { "collapsed": false }, "source": [ "Oops, we also need to provide the shape of the statistic. Most often, this will simply be the number of sensitive features. For example, the `PositiveRate` statistic just computes a single value for each sensitive feature. 
\n", "\n", "To know the number of sensitive features, we need to define some data first:" ] }, { "cell_type": "code", "execution_count": 4, "id": "5fdb88f4a060c4af", "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-26T15:32:16.300376300Z", "start_time": "2024-04-26T15:32:16.291743600Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of sensitive features: 2\n" ] } ], "source": [ "import torch\n", "torch.manual_seed(0)\n", "\n", "feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])\n", "sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])\n", "label = torch.tensor([[0.], [1.], [0.], [1.]])\n", "\n", "n_sensitive_features = sens.shape[1]\n", "print(f\"Number of sensitive features: {n_sensitive_features}\")" ] }, { "cell_type": "markdown", "source": [ "We can now construct the metric:" ], "metadata": { "collapsed": false }, "id": "a1a70a82c8ec2d31" }, { "cell_type": "code", "execution_count": 5, "id": "e2eeeea3657e4928", "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-26T15:32:16.849220300Z", "start_time": "2024-04-26T15:32:16.840157200Z" } }, "outputs": [], "source": [ "demographic_parity = LinearFractionalParity(PositiveRate(), stat_shape=(n_sensitive_features,))" ] }, { "cell_type": "markdown", "source": [ "`LinearFractionalParity` follows exactly the same interface as all other `Metric` classes in `torchmetrics`.\n", "\n", "For all details of this interface, check out the `torchmetrics` [documentation](https://lightning.ai/docs/torchmetrics/stable/pages/quickstart.html#module-metrics).\n", "\n", "Basically, it follows a three-step approach:\n", "1. Call `metric.update(args)`, where `args` in our case are the arguments necessary to compute the statistic.\n", "2. Call `metric.compute()`, which returns the violation of the fairness definition.\n", "3. 
Call `metric.reset()`, which resets the metric to its initial state.\n", "\n", "If you need to use any other `torchmetrics` settings, such as `compute_with_cache`, you can pass them as keyword arguments to the `LinearFractionalParity` class upon initialization." ], "metadata": { "collapsed": false }, "id": "c642ccb43bc88cd1" }, { "cell_type": "markdown", "id": "cda582927145b85c", "metadata": { "collapsed": false }, "source": [ "## Example\n", "\n", "Let's train a model and keep track of the demographic parity.\n", "\n", "Without fairret:" ] }, { "cell_type": "code", "execution_count": 6, "id": "2354981e1eb0c4d4", "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-26T15:32:17.376258200Z", "start_time": "2024-04-26T15:32:17.365862100Z" } }, "outputs": [], "source": [ "h_layer_dim = 16\n", "lr = 1e-3\n", "batch_size = 1024\n", "\n", "torch.manual_seed(0)\n", "\n", "def build_model():\n", "    model = torch.nn.Sequential(\n", "        torch.nn.Linear(feat.shape[1], h_layer_dim),\n", "        torch.nn.ReLU(),\n", "        torch.nn.Linear(h_layer_dim, 1)\n", "    )\n", "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", "    return model, optimizer\n", "\n", "from torch.utils.data import TensorDataset, DataLoader\n", "dataset = TensorDataset(feat, sens, label)\n", "dataloader = DataLoader(dataset, batch_size=batch_size)" ] }, { "cell_type": "code", "execution_count": 7, "id": "b497337ee1e2f46c", "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-26T15:32:17.907945900Z", "start_time": "2024-04-26T15:32:17.537121100Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.7091795206069946, dp: 0.03126168251037598\n", "Epoch: 1, loss: 0.7061765193939209, dp: 0.02563762664794922\n", "Epoch: 2, loss: 0.7033581733703613, dp: 0.020147204399108887\n", "Epoch: 3, loss: 0.7007156610488892, dp: 0.014800786972045898\n", "Epoch: 4, loss: 0.6982340812683105, dp: 0.009598910808563232\n", "Epoch: 5, loss: 0.6959078907966614, dp: 
0.00453948974609375\n", "Epoch: 6, loss: 0.6937355995178223, dp: 0.00037485361099243164\n", "Epoch: 7, loss: 0.6917158365249634, dp: 0.005139470100402832\n", "Epoch: 8, loss: 0.6898466944694519, dp: 0.009749293327331543\n", "Epoch: 9, loss: 0.6881252527236938, dp: 0.014199256896972656\n", "Epoch: 10, loss: 0.6865478754043579, dp: 0.01848423480987549\n", "Epoch: 11, loss: 0.6851094961166382, dp: 0.022599458694458008\n", "Epoch: 12, loss: 0.6838041543960571, dp: 0.02654099464416504\n", "Epoch: 13, loss: 0.6826250553131104, dp: 0.030305147171020508\n", "Epoch: 14, loss: 0.6815641522407532, dp: 0.03388887643814087\n", "Epoch: 15, loss: 0.6806124448776245, dp: 0.037290215492248535\n", "Epoch: 16, loss: 0.6797604560852051, dp: 0.040507614612579346\n", "Epoch: 17, loss: 0.6789975762367249, dp: 0.04354041814804077\n", "Epoch: 18, loss: 0.6783132553100586, dp: 0.04638934135437012\n", "Epoch: 19, loss: 0.6776963472366333, dp: 0.04905581474304199\n", "Epoch: 20, loss: 0.6771360039710999, dp: 0.051542043685913086\n", "Epoch: 21, loss: 0.6766215562820435, dp: 0.05385148525238037\n", "Epoch: 22, loss: 0.6761429309844971, dp: 0.055988192558288574\n", "Epoch: 23, loss: 0.6756909489631653, dp: 0.05795770883560181\n", "Epoch: 24, loss: 0.6752569675445557, dp: 0.05976557731628418\n", "Epoch: 25, loss: 0.6748337745666504, dp: 0.06141817569732666\n", "Epoch: 26, loss: 0.674415111541748, dp: 0.06292301416397095\n", "Epoch: 27, loss: 0.673996090888977, dp: 0.06428724527359009\n", "Epoch: 28, loss: 0.6735726594924927, dp: 0.06551891565322876\n", "Epoch: 29, loss: 0.6731564998626709, dp: 0.06661409139633179\n", "Epoch: 30, loss: 0.6727579236030579, dp: 0.06756091117858887\n", "Epoch: 31, loss: 0.672345757484436, dp: 0.06839364767074585\n", "Epoch: 32, loss: 0.6719199419021606, dp: 0.0691213607788086\n", "Epoch: 33, loss: 0.6714813709259033, dp: 0.06975340843200684\n", "Epoch: 34, loss: 0.6710319519042969, dp: 0.0702979564666748\n", "Epoch: 35, loss: 0.6705741882324219, dp: 
0.07076394557952881\n", "Epoch: 36, loss: 0.6701083779335022, dp: 0.07116150856018066\n", "Epoch: 37, loss: 0.669636607170105, dp: 0.07149994373321533\n", "Epoch: 38, loss: 0.6691610217094421, dp: 0.07178771495819092\n", "Epoch: 39, loss: 0.6686834096908569, dp: 0.07203388214111328\n", "Epoch: 40, loss: 0.6682056188583374, dp: 0.0722469687461853\n", "Epoch: 41, loss: 0.6677289009094238, dp: 0.07243525981903076\n", "Epoch: 42, loss: 0.667254626750946, dp: 0.07260710000991821\n", "Epoch: 43, loss: 0.6667835712432861, dp: 0.07276999950408936\n", "Epoch: 44, loss: 0.6663164496421814, dp: 0.07293164730072021\n", "Epoch: 45, loss: 0.6658533811569214, dp: 0.07309883832931519\n", "Epoch: 46, loss: 0.6653945446014404, dp: 0.07327830791473389\n", "Epoch: 47, loss: 0.6649397015571594, dp: 0.0734756588935852\n", "Epoch: 48, loss: 0.6644884347915649, dp: 0.07369661331176758\n", "Epoch: 49, loss: 0.6640403270721436, dp: 0.07394576072692871\n", "Epoch: 50, loss: 0.6635947227478027, dp: 0.07422733306884766\n", "Epoch: 51, loss: 0.6631510257720947, dp: 0.07454437017440796\n", "Epoch: 52, loss: 0.6628453135490417, dp: 0.07504117488861084\n", "Epoch: 53, loss: 0.6625917553901672, dp: 0.07558834552764893\n", "Epoch: 54, loss: 0.6623181104660034, dp: 0.07614338397979736\n", "Epoch: 55, loss: 0.6620256900787354, dp: 0.07671010494232178\n", "Epoch: 56, loss: 0.6617173552513123, dp: 0.07728838920593262\n", "Epoch: 57, loss: 0.6614043116569519, dp: 0.07785654067993164\n", "Epoch: 58, loss: 0.6610796451568604, dp: 0.07842350006103516\n", "Epoch: 59, loss: 0.6607442498207092, dp: 0.07899010181427002\n", "Epoch: 60, loss: 0.6603990793228149, dp: 0.07955652475357056\n", "Epoch: 61, loss: 0.6600450277328491, dp: 0.08012241125106812\n", "Epoch: 62, loss: 0.6596829295158386, dp: 0.08068704605102539\n", "Epoch: 63, loss: 0.6593135595321655, dp: 0.08124935626983643\n", "Epoch: 64, loss: 0.6589376330375671, dp: 0.08180773258209229\n", "Epoch: 65, loss: 0.6585558652877808, dp: 0.08236086368560791\n", 
"Epoch: 66, loss: 0.6581688523292542, dp: 0.08290684223175049\n", "Epoch: 67, loss: 0.6577771306037903, dp: 0.08344399929046631\n", "Epoch: 68, loss: 0.6574320793151855, dp: 0.08401530981063843\n", "Epoch: 69, loss: 0.6571431756019592, dp: 0.08470660448074341\n", "Epoch: 70, loss: 0.6568371653556824, dp: 0.08542871475219727\n", "Epoch: 71, loss: 0.6565203666687012, dp: 0.08615553379058838\n", "Epoch: 72, loss: 0.6561905145645142, dp: 0.08688771724700928\n", "Epoch: 73, loss: 0.6558488607406616, dp: 0.08762180805206299\n", "Epoch: 74, loss: 0.65549635887146, dp: 0.08835643529891968\n", "Epoch: 75, loss: 0.6551340818405151, dp: 0.08908939361572266\n", "Epoch: 76, loss: 0.6547629237174988, dp: 0.08981943130493164\n", "Epoch: 77, loss: 0.6544535160064697, dp: 0.09061932563781738\n", "Epoch: 78, loss: 0.6541627645492554, dp: 0.09137928485870361\n", "Epoch: 79, loss: 0.6538523435592651, dp: 0.09210765361785889\n", "Epoch: 80, loss: 0.6535260677337646, dp: 0.09280455112457275\n", "Epoch: 81, loss: 0.6531944274902344, dp: 0.09344327449798584\n", "Epoch: 82, loss: 0.6528521776199341, dp: 0.09403371810913086\n", "Epoch: 83, loss: 0.6525000333786011, dp: 0.09457921981811523\n", "Epoch: 84, loss: 0.652138888835907, dp: 0.09508275985717773\n", "Epoch: 85, loss: 0.6518597602844238, dp: 0.09564340114593506\n", "Epoch: 86, loss: 0.6515651345252991, dp: 0.09620797634124756\n", "Epoch: 87, loss: 0.6512539982795715, dp: 0.09679919481277466\n", "Epoch: 88, loss: 0.6509299874305725, dp: 0.09739184379577637\n", "Epoch: 89, loss: 0.650594174861908, dp: 0.09798645973205566\n", "Epoch: 90, loss: 0.6502466797828674, dp: 0.09860563278198242\n", "Epoch: 91, loss: 0.6498894691467285, dp: 0.09922534227371216\n", "Epoch: 92, loss: 0.6495950818061829, dp: 0.09992420673370361\n", "Epoch: 93, loss: 0.6493034362792969, dp: 0.10058420896530151\n", "Epoch: 94, loss: 0.6489962339401245, dp: 0.10119062662124634\n", "Epoch: 95, loss: 0.6486777067184448, dp: 0.10174638032913208\n", "Epoch: 96, loss: 
0.6483432650566101, dp: 0.10228461027145386\n", "Epoch: 97, loss: 0.6479994058609009, dp: 0.10277950763702393\n", "Epoch: 98, loss: 0.6476455330848694, dp: 0.10323858261108398\n", "Epoch: 99, loss: 0.6473514437675476, dp: 0.10378742218017578\n" ] } ], "source": [ "import numpy as np\n", "\n", "nb_epochs = 100\n", "model, optimizer = build_model()\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        loss.backward()\n", "\n", "        pred = torch.sigmoid(logit)\n", "        demographic_parity.update(pred, batch_sens)\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    dp_for_epoch = demographic_parity.compute()\n", "    demographic_parity.reset()\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}, dp: {dp_for_epoch}\")" ] }, { "cell_type": "markdown", "id": "dd9623e696885853", "metadata": { "collapsed": false }, "source": [ "With fairret:" ] }, { "cell_type": "code", "execution_count": 8, "id": "d015f2fdadada466", "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-26T15:32:18.345179600Z", "start_time": "2024-04-26T15:32:17.937616500Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.7429860830307007, dp: 0.0319594144821167\n", "Epoch: 1, loss: 0.7368547916412354, dp: 0.02823483943939209\n", "Epoch: 2, loss: 0.7306909561157227, dp: 0.024389266967773438\n", "Epoch: 3, loss: 0.7244950532913208, dp: 0.020473718643188477\n", "Epoch: 4, loss: 0.7182670831680298, dp: 0.016491293907165527\n", "Epoch: 5, loss: 0.7120081186294556, dp: 0.012444257736206055\n", "Epoch: 6, loss: 0.7057176828384399, dp: 0.008333325386047363\n", "Epoch: 7, loss: 0.6993964910507202, dp: 0.004159450531005859\n", "Epoch: 8, loss: 0.6933507323265076, dp: 7.653236389160156e-05\n", "Epoch: 9, loss: 0.6987630128860474, 
dp: 0.0022329092025756836\n", "Epoch: 10, loss: 0.700627326965332, dp: 0.0029762983322143555\n", "Epoch: 11, loss: 0.6999425292015076, dp: 0.0027066469192504883\n", "Epoch: 12, loss: 0.697374701499939, dp: 0.0016865730285644531\n", "Epoch: 13, loss: 0.693393886089325, dp: 9.85860824584961e-05\n", "Epoch: 14, loss: 0.6960413455963135, dp: 0.0019237995147705078\n", "Epoch: 15, loss: 0.6981306076049805, dp: 0.003303050994873047\n", "Epoch: 16, loss: 0.6993909478187561, dp: 0.004129886627197266\n", "Epoch: 17, loss: 0.6999239921569824, dp: 0.004477024078369141\n", "Epoch: 18, loss: 0.6998161673545837, dp: 0.004403471946716309\n", "Epoch: 19, loss: 0.6991404891014099, dp: 0.0039577484130859375\n", "Epoch: 20, loss: 0.697959303855896, dp: 0.0031795501708984375\n", "Epoch: 21, loss: 0.696326732635498, dp: 0.0021026134490966797\n", "Epoch: 22, loss: 0.6942894458770752, dp: 0.0007546544075012207\n", "Epoch: 23, loss: 0.6952502727508545, dp: 0.0008406639099121094\n", "Epoch: 24, loss: 0.6972097754478455, dp: 0.0016245245933532715\n", "Epoch: 25, loss: 0.6974166631698608, dp: 0.0017071962356567383\n", "Epoch: 26, loss: 0.6961178779602051, dp: 0.001187741756439209\n", "Epoch: 27, loss: 0.6935380101203918, dp: 0.00015437602996826172\n", "Epoch: 28, loss: 0.6951382160186768, dp: 0.0013152360916137695\n", "Epoch: 29, loss: 0.6966246366500854, dp: 0.002297043800354004\n", "Epoch: 30, loss: 0.6974535584449768, dp: 0.002843618392944336\n", "Epoch: 31, loss: 0.6976898908615112, dp: 0.0030000805854797363\n", "Epoch: 32, loss: 0.697391152381897, dp: 0.0028044581413269043\n", "Epoch: 33, loss: 0.6966089010238647, dp: 0.0022897720336914062\n", "Epoch: 34, loss: 0.6953877210617065, dp: 0.0014837980270385742\n", "Epoch: 35, loss: 0.6937685608863831, dp: 0.00041115283966064453\n", "Epoch: 36, loss: 0.6954146027565002, dp: 0.000906825065612793\n", "Epoch: 37, loss: 0.6969265937805176, dp: 0.0015110969543457031\n", "Epoch: 38, loss: 0.6968545317649841, dp: 0.001482248306274414\n", "Epoch: 39, 
loss: 0.6953892707824707, dp: 0.0008966922760009766\n", "Epoch: 40, loss: 0.693409264087677, dp: 0.00017368793487548828\n", "Epoch: 41, loss: 0.6943631172180176, dp: 0.000807642936706543\n", "Epoch: 42, loss: 0.6947269439697266, dp: 0.001049339771270752\n", "Epoch: 43, loss: 0.6945565342903137, dp: 0.0009363889694213867\n", "Epoch: 44, loss: 0.6939018368721008, dp: 0.0005016326904296875\n", "Epoch: 45, loss: 0.6937127113342285, dp: 0.00022614002227783203\n", "Epoch: 46, loss: 0.6939553618431091, dp: 0.00032335519790649414\n", "Epoch: 47, loss: 0.6933559775352478, dp: 0.00013881921768188477\n", "Epoch: 48, loss: 0.6934863924980164, dp: 0.0002257227897644043\n", "Epoch: 49, loss: 0.6932146549224854, dp: 2.6941299438476562e-05\n", "Epoch: 50, loss: 0.6935896873474121, dp: 0.00029468536376953125\n", "Epoch: 51, loss: 0.6935305595397949, dp: 0.00025534629821777344\n", "Epoch: 52, loss: 0.6934249401092529, dp: 0.00011110305786132812\n", "Epoch: 53, loss: 0.6933110952377319, dp: 0.0001093149185180664\n", "Epoch: 54, loss: 0.6932008266448975, dp: 2.1457672119140625e-05\n", "Epoch: 55, loss: 0.6937585473060608, dp: 0.00040733814239501953\n", "Epoch: 56, loss: 0.6938467025756836, dp: 0.0004661083221435547\n", "Epoch: 57, loss: 0.6934322118759155, dp: 0.0001900196075439453\n", "Epoch: 58, loss: 0.6941245198249817, dp: 0.00039076805114746094\n", "Epoch: 59, loss: 0.6940488815307617, dp: 0.0003604888916015625\n", "Epoch: 60, loss: 0.6934659481048584, dp: 0.00021249055862426758\n", "Epoch: 61, loss: 0.6937500834465027, dp: 0.0004018545150756836\n", "Epoch: 62, loss: 0.6935136318206787, dp: 0.0002442598342895508\n", "Epoch: 63, loss: 0.6937164068222046, dp: 0.0002276301383972168\n", "Epoch: 64, loss: 0.6934037208557129, dp: 0.00010263919830322266\n", "Epoch: 65, loss: 0.6939764618873596, dp: 0.0005524754524230957\n", "Epoch: 66, loss: 0.6943734884262085, dp: 0.000816643238067627\n", "Epoch: 67, loss: 0.6942408680915833, dp: 0.0007283687591552734\n", "Epoch: 68, loss: 
0.6936281323432922, dp: 0.0003205537796020508\n", "Epoch: 69, loss: 0.6940924525260925, dp: 0.00037801265716552734\n", "Epoch: 70, loss: 0.6942865252494812, dp: 0.00045561790466308594\n", "Epoch: 71, loss: 0.6931780576705933, dp: 2.0623207092285156e-05\n", "Epoch: 72, loss: 0.6933311820030212, dp: 0.00012260675430297852\n", "Epoch: 73, loss: 0.6934307217597961, dp: 0.00011348724365234375\n", "Epoch: 74, loss: 0.6934766173362732, dp: 0.0002194046974182129\n", "Epoch: 75, loss: 0.6934362053871155, dp: 0.00019252300262451172\n", "Epoch: 76, loss: 0.693547248840332, dp: 0.00016003847122192383\n", "Epoch: 77, loss: 0.6932504177093506, dp: 6.872415542602539e-05\n", "Epoch: 78, loss: 0.6932784914970398, dp: 5.257129669189453e-05\n", "Epoch: 79, loss: 0.6937205791473389, dp: 0.00038182735443115234\n", "Epoch: 80, loss: 0.6938185095787048, dp: 0.0004469156265258789\n", "Epoch: 81, loss: 0.6934152841567993, dp: 0.0001785755157470703\n", "Epoch: 82, loss: 0.6941288113594055, dp: 0.00039261579513549805\n", "Epoch: 83, loss: 0.6940430402755737, dp: 0.00035834312438964844\n", "Epoch: 84, loss: 0.6934714317321777, dp: 0.00021594762802124023\n", "Epoch: 85, loss: 0.6937586069107056, dp: 0.00040709972381591797\n", "Epoch: 86, loss: 0.6935272812843323, dp: 0.0002530813217163086\n", "Epoch: 87, loss: 0.6936823129653931, dp: 0.00021409988403320312\n", "Epoch: 88, loss: 0.6933677196502686, dp: 8.821487426757812e-05\n", "Epoch: 89, loss: 0.6939976215362549, dp: 0.0005662441253662109\n", "Epoch: 90, loss: 0.6943947076797485, dp: 0.0008302927017211914\n", "Epoch: 91, loss: 0.6942631006240845, dp: 0.000742793083190918\n", "Epoch: 92, loss: 0.693653404712677, dp: 0.00033724308013916016\n", "Epoch: 93, loss: 0.6940414309501648, dp: 0.00035768747329711914\n", "Epoch: 94, loss: 0.6942345499992371, dp: 0.00043487548828125\n", "Epoch: 95, loss: 0.6932075023651123, dp: 4.029273986816406e-05\n", "Epoch: 96, loss: 0.6933603882789612, dp: 0.00014215707778930664\n", "Epoch: 97, loss: 
0.6933803558349609, dp: 9.328126907348633e-05\n", "Epoch: 98, loss: 0.6935058832168579, dp: 0.00023895502090454102\n", "Epoch: 99, loss: 0.6934655904769897, dp: 0.00021207332611083984\n" ] } ], "source": [ "from fairret.loss import NormLoss\n", "fairness_strength = 1\n", "norm_loss = NormLoss(PositiveRate())\n", "\n", "model, optimizer = build_model()\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        loss += fairness_strength * norm_loss(logit, batch_sens)\n", "        loss.backward()\n", "\n", "        pred = torch.sigmoid(logit)\n", "        demographic_parity.update(pred, batch_sens)\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    dp_for_epoch = demographic_parity.compute()\n", "    demographic_parity.reset()\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}, dp: {dp_for_epoch}\")" ] }, { "cell_type": "markdown", "source": [ "## Clarification on the exact value of the metric\n", "\n", "Though it is generally agreed upon that demographic parity is achieved when the positive rates are equal, there is considerable ambiguity about how to measure the extent to which demographic parity is violated. In `fairret`, we assess this violation by comparing the statistic values for each sensitive feature to the statistic value for the entire population. \n", "\n", "In our example, this means that we compare the mean prediction value for each of the two groups to the overall mean prediction value.\n", "\n", "There is then one more ingredient: how the gap between these values is actually computed. We provide a few options, such as the absolute difference (`gap_abs_max`) and the relative absolute difference (`gap_relative_abs_max`). 
The former takes the maximum of the L1 norm of the gap, while the latter divides this maximum by the overall mean statistic. The latter is the default behavior of the `LinearFractionalParity` class.\n", "\n", "To use another gap function, simply pass it as an argument to the `LinearFractionalParity` class. Of course, you're also free to implement your own!" ], "metadata": { "collapsed": false }, "id": "7091aed3b6bdab9e" }, { "cell_type": "markdown", "source": [ "## What's next?\n", "\n", "In larger pipelines, you will likely want to define separate metrics for train, validation, and test set results. \n", "\n", "Also, you may want to assess many fairness definitions at once. These could all be defined as separate metrics, or you could make use of the `StackedLinearFractionalStatistic`, which keeps track of many statistics at the same time (see [Stacked Statistic.ipynb](./Stacked Statistic.ipynb)). However, keep in mind that you then won't get scalar values out of the `compute` method, but a tensor that stacks the violations of all statistics." ], "metadata": { "collapsed": false }, "id": "fee37f754297e527" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }